{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0c93a5cf-da1e-4df0-8758-fed132d288a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from mmcv import load, dump\n",
    "from pyskl.smp import *\n",
    "from pyskl.models import build_model\n",
    "from time import perf_counter, time\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b46e18ec-7657-438a-b1ae-eda33d4116a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "layout = 'nturgb+d'\n",
    "graph = dict(layout=layout, mode='spatial')\n",
    "graph_rdm = dict(layout=layout, mode='random', num_filter=8, init_off=.04, init_std=.02)\n",
    "graph_bin = dict(layout=layout, mode='binary_adj')\n",
    "aagcn_cfg = dict(type='AAGCN', graph_cfg=graph)\n",
    "ctrgcn_cfg = dict(type='CTRGCN', graph_cfg=graph)\n",
    "dgstgcn_cfg = dict(type='DGSTGCN', gcn_ratio=0.125, gcn_ctr='T', gcn_ada='T', tcn_ms_cfg=[(3, 1), (3, 2), (3, 3), (3, 4), ('max', 3), '1x1'], graph_cfg=graph_rdm)\n",
    "msg3d_cfg = dict(type='MSG3D', graph_cfg=graph_bin)\n",
    "stgcn_cfg = dict(type='STGCN', graph_cfg=graph)\n",
    "stgcnpp_cfg = dict(type='STGCN', gcn_adaptive='init', gcn_with_res=True, tcn_type='mstcn', graph_cfg=graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "290afcd7-8718-4a77-bbca-fde6244d0362",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each model name (used as the label when printing results) to its config.\n",
    "cfg_map = {\n",
    "    'aagcn': aagcn_cfg,\n",
    "    'ctrgcn': ctrgcn_cfg,\n",
    "    'dgstgcn': dgstgcn_cfg,\n",
    "    'msg3d': msg3d_cfg,\n",
    "    'stgcn': stgcn_cfg,\n",
    "    'stgcnpp': stgcnpp_cfg,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4d2f0e8b-f80c-49c5-9309-89ef768efd18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Benchmark loop: number of untimed warmup passes and of timed passes.\n",
    "warmup, iters = 10, 10\n",
    "\n",
    "# Dimensions of the random input fed to the GCN models.\n",
    "batch = 16\n",
    "num_person = 2\n",
    "seq_len = 100\n",
    "# Joint count is looked up from the skeleton layout chosen earlier.\n",
    "num_joints = {\n",
    "    'nturgb+d': 25,\n",
    "    'coco': 17,\n",
    "    'openpose': 18,\n",
    "}[layout]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3e8138df",
   "metadata": {},
   "source": [
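    "- Inference speed (FPS = samples per second). The GCN models were measured with the settings above; PoseC3D was measured with its own settings, defined in the PoseC3D cells below. The Notes column gives the iteration and warmup counts used on each device.\n",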
    "\n",
    "| Device / FPS | AAGCN | CTRGCN | DGSTGCN | MSG3D | STGCN | STGCN++ | PoseC3D | Notes |\n",
    "|--|--|--|--|--|--|--|--|--|\n",
    "| 2080Ti (Linux) | 274 | 353 | 409 | 111 | 518 | 476 | 41 | Iters 100, Warmup 100 |\n",
    "| 3060 (Windows) | 135 | 148 | 181 | 50 | 206 | 186 | 38 | Iters 100, Warmup 100 |\n",
    "| CPU (11800H Windows) | 7.5 | 7.2 | 8.1 | 4.3 | 13.8 | 8.9 | 2.9 | Iters 10, Warmup 10 |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e37487c8-6182-40ec-a69f-39b53d2ffe18",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "aagcn FPS: 7.512336739640367\n",
      "ctrgcn FPS: 7.161762537141133\n",
      "dgstgcn FPS: 8.095086634695258\n",
      "msg3d FPS: 4.3017158277906375\n",
      "stgcn FPS: 13.779976396437187\n",
      "stgcnpp FPS: 8.861014139175074\n"
     ]
    }
   ],
   "source": [
    "# Measure FPS (samples per second) of every GCN backbone in cfg_map.\n",
    "for k, v in cfg_map.items():\n",
    "    gcn = build_model(v)\n",
    "    gcn.eval()\n",
    "    # Create the random input once, before timing starts, so that\n",
    "    # random-number generation is not counted as inference time.\n",
    "    inp = torch.randn(batch, num_person, seq_len, num_joints, 3)\n",
    "    with torch.no_grad():\n",
    "        # Untimed warmup passes.\n",
    "        for _ in range(warmup):\n",
    "            out = gcn(inp)\n",
    "        # Timed passes. perf_counter is a monotonic, high-resolution clock,\n",
    "        # unlike time(), which follows the (adjustable) system wall clock.\n",
    "        # NOTE(review): no device transfer or synchronization happens here; if the\n",
    "        # model is moved to a GPU, call torch.cuda.synchronize() before reading the clock.\n",
    "        start = perf_counter()\n",
    "        for _ in range(iters):\n",
    "            out = gcn(inp)\n",
    "        end = perf_counter()\n",
    "    # Samples processed during the timed passes divided by the elapsed seconds.\n",
    "    fps = (batch * iters) / (end - start)\n",
    "    print(f'{k} FPS: {fps}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d63caa69",
   "metadata": {},
   "outputs": [],
   "source": [
    "posec3d_cfg=dict(\n",
    "    type='ResNet3dSlowOnly',\n",
    "    in_channels=17,\n",
    "    base_channels=32,\n",
    "    num_stages=3,\n",
    "    out_indices=(2, ),\n",
    "    stage_blocks=(4, 6, 3),\n",
    "    conv1_stride=(1, 1),\n",
    "    pool1_stride=(1, 1),\n",
    "    inflate=(0, 1, 1),\n",
    "    spatial_strides=(2, 2, 2),\n",
    "    temporal_strides=(1, 1, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "1c0060ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = 8\n",
    "warmup = 10\n",
    "iters = 10\n",
    "num_joints = 17\n",
    "seq_len = 48\n",
    "size = 56"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "42d9084b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PoseC3D FPS: 2.8657506130444377\n"
     ]
    }
   ],
   "source": [
    "# Measure FPS (samples per second) of the PoseC3D backbone.\n",
    "posec3d = build_model(posec3d_cfg)\n",
    "posec3d.eval()\n",
    "# Create the random input once, before timing starts: generating a tensor of\n",
    "# this size on every pass would otherwise be counted as inference time.\n",
    "inp = torch.randn(batch, num_joints, seq_len, size, size)\n",
    "with torch.no_grad():\n",
    "    # Untimed warmup passes.\n",
    "    for _ in range(warmup):\n",
    "        out = posec3d(inp)\n",
    "    # Timed passes. perf_counter is a monotonic, high-resolution clock,\n",
    "    # unlike time(), which follows the (adjustable) system wall clock.\n",
    "    # NOTE(review): no device transfer or synchronization happens here; if the\n",
    "    # model is moved to a GPU, call torch.cuda.synchronize() before reading the clock.\n",
    "    start = perf_counter()\n",
    "    for _ in range(iters):\n",
    "        out = posec3d(inp)\n",
    "    end = perf_counter()\n",
    "# Samples processed during the timed passes divided by the elapsed seconds.\n",
    "fps = (batch * iters) / (end - start)\n",
    "print(f'PoseC3D FPS: {fps}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "e9762a6b93aa21d4aa12033e45e3b754f20c88c84120cffc73bc7fccd60bfa55"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
